import matplotlib.pyplot as plt
import numpy as np
import os

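'''
    The i-th panel plots the i-th order basis (mode) produced by each method (i = 1, 2, 3, 4, 5, 6).
    The six panels are laid out as subplots and saved together as a single image.
'''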

# Use Times New Roman for all figure text
# (alternative font previously tried: 'Latin Modern Roman').
plt.rcParams['font.family'] = 'Times New Roman'

def plotAirfoilGeom(x, y_list, labels, title_name, show_legend=False):
    '''
        Plot a set of curves sharing one x-axis on the current pyplot axes.

    :param x: x coordinates shared by every curve (handed to plt.plot as-is).
    :param y_list: 2-D array-like holding one curve per column; a 1-D input
        is treated as a single curve.
    :param labels: legend labels, indexed by the column of ``y_list``.
    :param title_name: title drawn above the axes.
    :param show_legend: when True, draw a legend anchored at (1, 0.67).
    :return: None; everything is drawn through the pyplot state machine.
    '''

    # Style tables: the first curve is solid black, the following ones
    # alternate red/blue over the three dashed patterns.
    linestyle_list = ['solid',
                      ':', '--', '-.',
                      ':', '--', '-.']
    color_list = ["black", "red", "blue", "red", "blue", "red", "blue"]

    curves = np.asarray(y_list)
    if curves.ndim == 1:
        # A single curve: promote it to one column so the loop below works.
        curves = curves[:, np.newaxis]

    for i in range(curves.shape[1]):
        # Wrap around the style tables so more than seven curves do not
        # raise IndexError.
        plt.plot(x, curves[:, i],
                 linewidth=1,
                 linestyle=linestyle_list[i % len(linestyle_list)],
                 label=labels[i],
                 c=color_list[i % len(color_list)])

    font = {'weight': 'normal',
            'size': 16}

    # Title
    plt.title(title_name, fontdict=font)

    # Axis labels
    plt.xlabel("X/C", font)
    plt.ylabel("Y/C", font)

    # Tick label size (also reused for the legend text)
    font1 = 13
    plt.xticks(size=font1)
    plt.yticks(size=font1)

    if show_legend:
        plt.legend(fontsize=font1, bbox_to_anchor=(1, 0.67))

    plt.grid()


if __name__ == "__main__":

    # Shared x coordinates: first column of the space-delimited file.
    x = np.loadtxt("basis_data/x.dat", delimiter=' ')[:, 0]

    # First six basis columns of each method, listed in the same order as
    # `labels` below (column j of every panel is drawn with labels[j]).
    basis_files = ["basis_data/POD_basis_NOML.csv",
                   "basis_data/POD_basis_LGB.csv",
                   "basis_data/POD_basis_XGB.csv",
                   "basis_data/POD_basis_GBDT.csv"]
    y_list = [np.loadtxt(path, delimiter=',')[:, 0:6] for path in basis_files]
    labels = ["CFD-MPOD", "DMPOD(LGB)", "DMPOD(XGB)", "DMPOD(GBDT)"]

    title_names = [r"$1^{st}$ mode", r"$2^{nd}$ mode",
                   r"$3^{rd}$ mode", r"$4^{th}$ mode",
                   r"$5^{th}$ mode", r"$6^{th}$ mode"]
    # Only the fourth panel carries the legend.
    show_legend_list = [False, False, False, True, False, False]

    fig_width = 12
    fig_height = 11

    plt.figure(figsize=(fig_width, fig_height), dpi=1000)

    # Spacing between subplots
    plt.subplots_adjust(wspace=0.25, hspace=0.38)

    # One panel per mode (modes 1-6), laid out on a 3x2 grid.
    for i, (title_name, show_legend) in enumerate(zip(title_names,
                                                      show_legend_list)):
        # Gather mode i of every method as the columns of one array; the
        # row count follows the data instead of being hard-coded.
        y_i = np.column_stack([y[:, i] for y in y_list])

        plt.subplot(3, 2, i + 1)
        plotAirfoilGeom(x, y_i, labels, title_name, show_legend)

    save_path = r"basis_data/Basis.png"
    plt.savefig(save_path, dpi=1000, bbox_inches='tight')
    plt.close()

    print("二维翼型可视化图已保存在{}".format(save_path))
